{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7e2c4a1",
   "metadata": {},
   "source": [
    "# Matrix completion by alternating minimization\n",
    "\n",
    "Recover a rank-$r$ matrix $M \\in \\mathbb{R}^{n \\times n}$ from a uniformly random subset $\\Omega$ of its entries by solving\n",
    "\n",
    "$$\\min_{L \\in \\mathbb{R}^{n \\times r},\\; R \\in \\mathbb{R}^{r \\times n}} \\sum_{(i,j) \\in \\Omega} \\left(M_{ij} - L_{i,:} R_{:,j}\\right)^2$$\n",
    "\n",
    "with alternating least squares: fix $R$ and solve an independent least-squares problem for every row of $L$, then fix $L$ and do the same for every column of $R$.\n",
    "\n",
    "With 40 observed entries per row on average ($n = 100$, $r = 5$) the iterates recover $M$; with fewer observations (e.g. 30 per row) they can get trapped in a local minimum."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a0d6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "SEED = 0  # for reproducibility\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Problem size\n",
    "n = 100  # number of rows (and columns) of M\n",
    "r = 5  # rank of M; L is n x r and R is r x n\n",
    "num_obs = n * 40  # total number of observed entries (40 per row on average)\n",
    "# With fewer observations (e.g. num_obs = n * 30) the iterations can get trapped in a local minimum.\n",
    "\n",
    "# Solver settings\n",
    "INIT_SCALE = 0.1  # standard deviation of the random initialization of R\n",
    "MAX_ITERS = 100  # maximum number of alternating-minimization sweeps\n",
    "TOL = 1e-10  # stop once the RMSE on the observed entries drops below this\n",
    "PRINT_EVERY = 10  # report progress every PRINT_EVERY sweeps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c1d8e22",
   "metadata": {},
   "source": [
    "## 1. Ground truth and observed entries\n",
    "\n",
    "$M = U V$ with Gaussian factors $U \\in \\mathbb{R}^{n \\times r}$, $V \\in \\mathbb{R}^{r \\times n}$, and $\\Omega$ is a set of `num_obs` distinct entries drawn uniformly at random."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdf8d1f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "M = np.random.randn(n, r) @ np.random.randn(r, n)  # n x n, rank r\n",
    "\n",
    "sampled_indices = np.random.choice(n * n, num_obs, replace=False)\n",
    "obs_rows, obs_cols = np.unravel_index(sampled_indices, (n, n))\n",
    "\n",
    "# One pass over Omega builds both index lists:\n",
    "# Omega_I[i] = observed columns of row i, Omega_J[j] = observed rows of column j.\n",
    "Omega_I = [[] for _ in range(n)]\n",
    "Omega_J = [[] for _ in range(n)]\n",
    "for i, j in zip(obs_rows, obs_cols):\n",
    "    Omega_I[i].append(j)\n",
    "    Omega_J[j].append(i)\n",
    "\n",
    "assert sum(map(len, Omega_I)) == sum(map(len, Omega_J)) == num_obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0b7a43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each least-squares subproblem has r unknowns, so a row or column with fewer than\n",
    "# r observations is underdetermined; one with no observations cannot be recovered at all.\n",
    "sparse_rows = np.flatnonzero(np.bincount(obs_rows, minlength=n) < r)\n",
    "sparse_cols = np.flatnonzero(np.bincount(obs_cols, minlength=n) < r)\n",
    "if sparse_rows.size:\n",
    "    warnings.warn(f\"Rows with fewer than r={r} observations: {sparse_rows.tolist()}\")\n",
    "if sparse_cols.size:\n",
    "    warnings.warn(f\"Columns with fewer than r={r} observations: {sparse_cols.tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d41a6c07",
   "metadata": {},
   "source": [
    "## 2. Alternating minimization\n",
    "\n",
    "Each half-step splits into independent least-squares problems with $r$ unknowns. They are solved with `np.linalg.lstsq` instead of `np.linalg.inv(A) @ b` on the normal equations: this avoids squaring the condition number, and rows or columns with fewer than $r$ observations (singular normal-equation matrix) get the minimum-norm solution instead of an inverse of a singular matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8f5e19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_factor(targets: np.ndarray, basis: np.ndarray, index_lists: list) -> np.ndarray:\n",
    "    \"\"\"Solve one half-step of alternating least squares.\n",
    "\n",
    "    For each k, find x_k minimizing ||basis[idx] @ x_k - targets[k, idx]||^2\n",
    "    with idx = index_lists[k].\n",
    "\n",
    "    Args:\n",
    "        targets: matrix whose row k holds the values fitted by subproblem k.\n",
    "        basis: (m, r) matrix of the fixed factor; basis[idx] is the design matrix.\n",
    "        index_lists: index_lists[k] holds the observed positions of subproblem k.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (len(index_lists), r) whose row k is x_k. Subproblems with no\n",
    "        observations stay at zero; underdetermined ones get the minimum-norm solution.\n",
    "    \"\"\"\n",
    "    solution = np.zeros((len(index_lists), basis.shape[1]))\n",
    "    for k, idx in enumerate(index_lists):\n",
    "        if len(idx) == 0:\n",
    "            continue  # nothing observed, nothing to fit\n",
    "        solution[k] = np.linalg.lstsq(basis[idx], targets[k, idx], rcond=None)[0]\n",
    "    return solution\n",
    "\n",
    "\n",
    "def observed_rmse(M: np.ndarray, L: np.ndarray, R: np.ndarray,\n",
    "                  rows: np.ndarray, cols: np.ndarray) -> float:\n",
    "    \"\"\"Root-mean-square error of L @ R against M over the entries (rows[k], cols[k]).\"\"\"\n",
    "    predictions = np.sum(L[rows] * R[:, cols].T, axis=1)  # (L @ R)[rows[k], cols[k]]\n",
    "    return float(np.sqrt(np.mean((M[rows, cols] - predictions) ** 2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a3c9d84",
   "metadata": {},
   "source": [
    "Only $R$ needs an initial value, because every sweep first solves for $L$ given $R$. Initializing $R$ with zeros does not work: every row subproblem then has an all-zero design matrix, so $L$ (and in turn $R$) stays at zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e57f1b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "R = INIT_SCALE * np.random.randn(r, n)\n",
    "\n",
    "for it in range(1, MAX_ITERS + 1):\n",
    "    L = solve_factor(M, R.T, Omega_I)  # row i of L fits M[i, Omega_I[i]] for fixed R\n",
    "    R = solve_factor(M.T, L, Omega_J).T  # column j of R fits M[Omega_J[j], j] for fixed L\n",
    "\n",
    "    rmse = observed_rmse(M, L, R, obs_rows, obs_cols)\n",
    "    converged = rmse < TOL\n",
    "    if it == 1 or it % PRINT_EVERY == 0 or converged or it == MAX_ITERS:\n",
    "        print(f\"Iteration {it:3d}: RMSE on observed entries = {rmse:.4e}\")\n",
    "    if converged:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c9d4e61",
   "metadata": {},
   "source": [
    "## 3. Recovery error on all entries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f18a2d5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "M_hat = L @ R\n",
    "total_rmse = np.sqrt(np.mean((M - M_hat) ** 2))\n",
    "relative_error = np.linalg.norm(M - M_hat) / np.linalg.norm(M)\n",
    "print(f\"Final RMSE on all entries: {total_rmse:.4e} (relative Frobenius error {relative_error:.2e})\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}